#!/usr/bin/env python

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import socket
import subprocess
import sys
import warnings
from argparse import ArgumentParser


# Command-line interface of the launcher. Options that are not recognised here
# are forwarded verbatim to train.py (see main()).
optparser = ArgumentParser(description="train resnet50 with MXNet")
optparser.add_argument("-n", "--n-GPUs", type=int, default=8,
                       help="number of GPUs to use; default = 8")
optparser.add_argument("-b", "--batch-size", type=int, default=208,
                       help="batch size per GPU; default = 208")
optparser.add_argument("-e", "--num-epochs", type=int, default=90,
                       help="number of epochs; default = 90")
optparser.add_argument("-l", "--lr", type=float, default=0.1,
                       help="learning rate; default = 0.1; "
                            "IMPORTANT: true learning rate will be calculated as `lr * batch_size/256`")
optparser.add_argument("--no-val", action="store_true",
                       help="if set no validation will be performed")
optparser.add_argument("--no-dali", action="store_true", default=False,
                       help="use default MXNet pipeline instead of DALI")
optparser.add_argument("--data-root", type=str,
                       help="Directory with RecordIO data files",
                       default="/data/imagenet/train-val-recordio-passthrough")
optparser.add_argument("--data-nthreads", type=int,
                       help="number of threads for data loading; default = 40",
                       default=40)
optparser.add_argument("--dtype", type=str,
                       help="Precision, float16 or float32", default="float16")

# Environment variables exported (unconditionally, overriding any inherited
# value) before the training process is started.
MXNET_ENV = {
    'MXNET_UPDATE_ON_KVSTORE': "0",
    'MXNET_EXEC_ENABLE_ADDTO': "1",
    'MXNET_USE_TENSORRT': "0",
    'MXNET_GPU_WORKER_NTHREADS': "2",
    'MXNET_GPU_COPY_NTHREADS': "1",
    'MXNET_OPTIMIZER_AGGREGATION_SIZE': "54",
}


def build_command(opts, extra_args, script_dir=None):
    """Return the argv list that launches train.py for the given options.

    Args:
        opts: namespace produced by ``optparser``; it is not modified.
        extra_args: list of unrecognised command-line arguments, appended
            verbatim at the end of the command.
        script_dir: directory containing train.py. Defaults to the directory
            of this file, resolved to an absolute path so that the result
            does not depend on how the launcher itself was invoked.

    Returns:
        A list of strings suitable for ``subprocess.call`` without a shell.
    """
    if script_dir is None:
        # abspath first: dirname of a bare relative file name is "", which
        # would otherwise yield the bogus path "/train.py".
        script_dir = os.path.dirname(os.path.abspath(__file__))

    # --batch-size is given per GPU; train.py receives the global batch size.
    global_batch = opts.batch_size * opts.n_GPUs
    # Learning rate is scaled relative to a batch of 256. Divide by a float so
    # the ratio is not floored when running under Python 2.
    scaled_lr = opts.lr * (global_batch / 256.0)

    use_fp16 = opts.dtype == "float16"
    if use_fp16:
        # 4 channels with DALI, 3 with the native MXNet pipeline.
        # NOTE(review): presumably the 4th channel is padding for the NHWC
        # fp16 path -- confirm against train.py.
        channels = 4 - int(opts.no_dali)
    else:
        channels = 3

    # Fall back to "python" from PATH if the interpreter path is unavailable.
    cmd = [sys.executable or "python", os.path.join(script_dir, "train.py")]
    cmd += ["--num-layers", "50"]
    cmd += ["--data-train", opts.data_root + "/train.rec"]
    cmd += ["--data-train-idx", opts.data_root + "/train.idx"]
    if not opts.no_val:
        cmd += ["--data-val", opts.data_root + "/val.rec"]
        cmd += ["--data-val-idx", opts.data_root + "/val.idx"]
    cmd += ["--data-nthreads", str(opts.data_nthreads)]
    cmd += ["--optimizer", "sgd", "--dtype", opts.dtype]
    cmd += ["--lr-step-epochs", "30,60,80", "--max-random-area", "1"]
    cmd += ["--min-random-area", "0.05", "--max-random-scale", "1"]
    cmd += ["--min-random-scale", "1", "--min-random-aspect-ratio", "0.75"]
    cmd += ["--max-random-aspect-ratio", "1.33", "--max-random-shear-ratio", "0"]
    cmd += ["--max-random-rotate-angle", "0", "--random-resized-crop", "1"]
    cmd += ["--random-crop", "0", "--random-mirror", "1"]
    cmd += ["--image-shape", str(channels) + ",224,224", "--warmup-epochs", "5"]
    cmd += ["--disp-batches", "20"]
    cmd += ["--batchnorm-mom", "0.9", "--batchnorm-eps", "1e-5"]
    if use_fp16:
        cmd += ["--fuse-bn-relu", "1"]
        cmd += ["--input-layout", "NHWC", "--conv-layout", "NHWC"]
        cmd += ["--batchnorm-layout", "NHWC", "--pooling-layout", "NHWC"]
        cmd += ["--conv-algo", "1", "--force-tensor-core", "1"]
        cmd += ["--fuse-bn-add-relu", "1"]

    cmd += ["--kv-store", "device"]
    if not opts.no_dali:
        cmd += ["--use-dali"]
        cmd += ["--dali-prefetch-queue", "2", "--dali-nvjpeg-memory-padding", "64"]
    cmd += ["--lr", str(scaled_lr)]
    cmd += ["--gpus", ",".join(str(gpu) for gpu in range(opts.n_GPUs))]
    cmd += ["--batch-size", str(global_batch)]
    cmd += ["--num-epochs", str(opts.num_epochs)]

    # Pass-through arguments go last so they can override the defaults above.
    cmd += list(extra_args)
    return cmd


def to_exit_code(returncode):
    """Map a ``subprocess`` return code to a process exit code.

    Non-negative codes are returned unchanged. A negative code ``-N`` (child
    terminated by signal N) becomes ``128 + N``, the convention used by
    shells, so that a killed training run is never reported as success.
    """
    if returncode < 0:
        return 128 - returncode
    return returncode


def main(argv=None):
    """Parse options, launch train.py and return its exit code.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The exit code of the training process (see ``to_exit_code``).
    """
    opts, extra_args = optparser.parse_known_args(argv)
    if opts.n_GPUs < 1:
        # Zero GPUs would give a global batch size of 0 and an empty --gpus.
        optparser.error("--n-GPUs must be at least 1")

    os.environ.update(MXNET_ENV)

    # Run without a shell: arguments containing spaces, quotes or `$` reach
    # train.py unchanged, and the child's real exit status is returned
    # (os.system returns an encoded wait status that exit() truncates to 0).
    return to_exit_code(subprocess.call(build_command(opts, extra_args)))


if __name__ == "__main__":
    sys.exit(main())
